#=
Compute the distribution of hscfg
=#


include("../src/qmcmary.jl")
using ..qmcmary

using Dates
using LinearAlgebra


##
##[0.9320867696990618, -0.9320867696990618, 0.9320867696990618, -0.9320867696990618]
##[0.36223508078590083, 0.36223508078590083, 0.36223508078590083, 0.36223508078590083]
##
"""
    chai(L, Nt)

Compare exact configuration weights with sampled frequencies.

Reads the file `"hscfg1311"`. Its first line is a bracketed, comma-separated list
of angles. All `2^(Nt*L)` strings built from the two-character tokens `"-1"` and
`"+1"` are enumerated, and for each one the weight
`abs(det(I + prod(reverse(allbmats))))` is computed from the matrices returned by
`initialize_SS`. The raw and normalised weights are printed.

The remainder of the file is then read as a sequence of records, each made of one
header line followed by `Nt` bracketed rows. Every record is converted to the same
string encoding and counted. The counts are printed raw and divided by the integer
parsed from the last header line read.

Returns `nothing`; all results are printed.

NOTE(review): the last header line is presumably the total number of samples —
confirm against the code that writes the file.
"""
function chai(L, Nt)
    fname = "hscfg1311"#ARGS[1]
    # The do-block guarantees the file is closed even if parsing fails.
    open(fname, "r") do fid
        the = parse.(Float64, split(readline(fid)[2:end-1], ','))
        println(the)
        Nx = L
        nx = zeros(Nx)
        ny = sin.(the)
        nz = cos.(the)
        println(ny)
        println(nz)
        # Generate all configurations: every pass appends "-1" to each existing
        # string, then "+1" to each existing string, doubling the pool.
        hslen = Nt*L
        hspool = ["-1", "+1"]
        for _ in Base.OneTo(hslen-1)
            hspool = vcat(hspool .* "-1", hspool .* "+1")
        end
        println(hspool)
        wgtpool = zeros(Float64, 2^hslen)
        # Compute the weight of every configuration.
        hk = lattice_chain(ComplexF64, L, -1.0+0.0im)
        engs = eigvals(hk)
        println(engs)
        Ui = 4*ones(Nx)
        Ng = 3
        hsdict = Dict{String,Int}()
        sp = default_splitting3(Nt, hk, Ui; Z2=true)
        for hi in eachindex(wgtpool)
            hs = hspool[hi]
            hsdict[hs] = hi
            hscfg = Matrix{Int}(undef, Nt, Nx)
            for bi in Base.OneTo(Nt), xi in Base.OneTo(Nx)
                # Each entry occupies two characters ("+1" or "-1") of the string.
                sta = (bi-1)*Nx*2+2*xi-1
                hscfg[bi, xi] = parse(Int, hs[sta:sta+1])
            end
            println(hs, hscfg)
            _, _, allbmats, _ = initialize_SS(Nt, Ng, sp, hscfg, nx, ny, nz)
            # `I` adapts to the matrix size, so this is not restricted to 2×2.
            wgt = det(I + prod(reverse(allbmats)))
            wgtpool[hi] = abs(wgt)
        end
        println(wgtpool)
        println(wgtpool/sum(wgtpool))
        # Count how often each configuration occurs in the sampled records.
        freqpool = zeros(Float64, 2^hslen)
        bistr = "0"
        while !eof(fid)
            bistr = readline(fid)
            pieces = String[]
            for _ in Base.OneTo(Nt)
                row = readline(fid)
                for tok in split(row[2:end-1], ',')
                    # strip tolerates "[1, -1]"-style rows with spaces after commas.
                    push!(pieces, strip(tok) == "1" ? "+1" : "-1")
                end
            end
            freqpool[hsdict[join(pieces)]] += 1
        end
        nsp = parse(Int, bistr)
        println(freqpool)
        println(freqpool/nsp)
    end
    return nothing
end


"""
    sgn_scratch(ss::ScrollSVD{T}) where T

Return a real log-magnitude assembled from the SVD factors stored in `ss.F`.

The position of the first `true` entry in `ss.L` selects which stored
factorizations serve as the "left" (`VL, DL, UL`) and "right" (`UR, DR, VR`)
triples; identity matrices stand in for a missing side. Each set of singular
values is split into a part `> 1` and a part `<= 1`, and the matrix

    M = inv(DRB) * (UL*UR)' * inv(DLB) + DRS * (VR*VL) * DLS

is decomposed with a QR-iteration SVD. The result is
`log|det(Fm.Vt*UL) * det(UR*Fm.U)|` plus the sum of the logs of the singular
values of `M`, of `DLB` and of `DRB`.

NOTE(review): this presumably equals `log|det(I + UR*DR*VR*VL*DL*UL)|` when the
`U` factors are unitary — confirm. Despite the name, no sign or phase is returned.
"""
function sgn_scratch(ss::ScrollSVD{T}) where T
    dim = size(ss.B[end])[1]
    eye = Diagonal(ones(dim))
    factors(F) = (F.U, Diagonal(F.S), F.Vt)
    ptr = findfirst(ss.L)
    if isnothing(ptr)
        VL, DL, UL = eye, eye, eye
        UR, DR, VR = factors(ss.F[end])
    elseif ptr == 1
        # NOTE (original author): this arrangement can suddenly lose numerical
        # stability. The suggested alternative, left disabled, puts the identity
        # on the left and ss.F[1] on the right.
        VL, DL, UL = factors(ss.F[1])
        UR, DR, VR = eye, eye, eye
    else
        VL, DL, UL = factors(ss.F[ptr])
        UR, DR, VR = factors(ss.F[ptr-1])
    end
    # Split each diagonal so that D = (part > 1) * (part <= 1).
    above(d) = d > 1.0 ? d : 1.0
    below(d) = d > 1.0 ? 1.0 : d
    rng = Base.OneTo(dim)
    DLB = Diagonal(Float64[above(DL[i, i]) for i in rng])
    DLS = Diagonal(Float64[below(DL[i, i]) for i in rng])
    DRB = Diagonal(Float64[above(DR[i, i]) for i in rng])
    DRS = Diagonal(Float64[below(DR[i, i]) for i in rng])
    #
    M = inv(DRB)*adjoint(UL*UR)*inv(DLB) + DRS*(VR*VL)*DLS
    Fm = svd(M, alg=LinearAlgebra.QRIteration())
    # Log-magnitude of the determinant of the orthogonal/unitary factors ...
    udet = det(Fm.Vt*UL)*det(UR*Fm.U)
    logabs = log(abs(udet))
    # ... plus the logs of all scale factors, accumulated in a fixed order.
    for scales in (Fm.S, diag(DLB), diag(DRB)), s in scales
        logabs += log(s)
    end
    return logabs
end


"""
    step_theta1(L, fname, the; psi=zeros(2*L^2))

Average the gradient returned by `meas_grad` over all records stored in `fname`.

# Arguments
- `L`: lattice size passed to `lattice_hexagonal`; the number of sites is `2*L^2`.
- `fname`: path of the input file. Its first line is skipped; each following
  record is one header line followed by `Nt = 60` bracketed, comma-separated rows
  of `2*L^2` entries (`"1"` maps to `+1`, anything else to `-1`).
- `the`: polar angles, one per site.

# Keywords
- `psi`: azimuthal angles, one per site. Defaults to all zeros, which is the
  value the function previously hard-coded.

# Returns
A tuple `(ttbar/nsp, tpbar/nsp)`: the accumulated real parts of the gradients
with respect to `the` and `psi`, divided by `nsp`, the integer parsed from the
last header line read.

Also prints the averaged phase and the averaged `-sgn_scratch(ss)`.

NOTE (original author): the derivative is that of `-log(S)`.
NOTE(review): `nsp` is presumably the total number of samples — confirm. A file
with no records gives `nsp == 0` and non-finite results.
"""
function step_theta1(L, fname, the; psi=zeros(2*L^2))
    Nx = 2*L^2
    # Unit vector n = (nx, ny, nz) built from the angles.
    nx = sin.(psi).*sin.(the)
    ny = cos.(psi).*sin.(the)
    nz = cos.(the)
    # Partial derivatives of n with respect to the and psi. They do not depend on
    # the sampled configuration, so they are computed once, outside the loop.
    pnxpt = sin.(psi) .* cos.(the)
    pnypt = cos.(psi) .* cos.(the)
    pnzpt = -sin.(the)
    pnxpp = cos.(psi) .* sin.(the)
    pnypp = -sin.(psi) .* sin.(the)
    pnzpp = zeros(Nx)
    #
    hk = lattice_hexagonal(ComplexF64, L, -1.0+0.0im; λ=sqrt(3)+0.0im)
    Ui = 10*ones(Nx)
    Nt = 60
    Ng = 4
    sp = default_splitting2(Nt, hk, Ui; Z2=false)
    tapemap = pbpn_calc_tapemap(sp, nx, ny, nz)
    # Accumulators over all records.
    ttbar = zeros(Nx)
    tpbar = zeros(Nx)
    tph = 0.0
    tlogabs = 0.0
    # Accumulated for optional reporting; not part of the current output.
    nup_i = zeros(2*Nx)
    bistr = "0"
    fid = open(fname, "r")
    try
        # The first line holds the angles stored in the file; the caller supplies
        # `the` explicitly, so the line is only skipped.
        readline(fid)
        # Re-read the stored configurations.
        while !eof(fid)
            bistr = readline(fid)
            hscfg = Matrix{Int}(undef, Nt, Nx)
            for bi in Base.OneTo(Nt)
                cfg = readline(fid)
                cfg = split(cfg[2:end-1], ',')
                for xi in Base.OneTo(Nx)
                    hscfg[bi, xi] = strip(cfg[xi]) == "1" ? +1 : -1
                end
            end
            _, _, allbmats, ss = initialize_SS(Nt, Ng, sp, hscfg, nx, ny, nz)
            gf, ph = eq_green_scratch(ss)
            tlogabs += -sgn_scratch(ss)
            tph += ph
            nup_i += diag(gf)
            #
            nxbar, nybar, nzbar = meas_grad(ss, allbmats, sp, hscfg, nx, ny, nz; tapemap=tapemap)
            # Chain rule: gradient w.r.t. n -> gradient w.r.t. (the, psi).
            thetabar = @. nxbar * pnxpt + nybar * pnypt + nzbar * pnzpt
            psibar = @. nxbar * pnxpp + nybar * pnypp + nzbar * pnzpp
            ttbar += real(thetabar)
            tpbar += real(psibar)
        end
    finally
        # Always release the file handle, even if a record is malformed.
        close(fid)
    end
    nsp = parse(Int, bistr)
    println(tph/nsp, " ", tlogabs/nsp)
    return ttbar/nsp, tpbar/nsp
end

"""
    optimal_theta1(L; fname="hscfg1430", niter=10)

Iteratively update the angles `the` and `psi` using gradients from `step_theta1`.

The initial `the` is parsed from the first line of `fname` (a bracketed,
comma-separated list of floats); `psi` starts at zero with `2*L^2` entries. Each
of the `niter` iterations calls `step_theta1(L, fname, the)`, passes both
gradients through `next(Adam, ...)`, adds the returned increments to `the` and
`psi`, and prints both vectors.

# Keywords
- `fname`: input file; defaults to the previously hard-coded `"hscfg1430"`.
- `niter`: number of update steps; defaults to the previously hard-coded `10`.

Returns `nothing`; progress is printed.

NOTE(review): `step_theta1` is called without `psi`, so the updated `psi` never
influences the gradients — confirm whether that is intended.
"""
function optimal_theta1(L; fname="hscfg1430", niter=10)
    Nx = 2*L^2
    # open(f, ...) closes the file even if reading fails.
    header = open(readline, fname, "r")
    the = parse.(Float64, split(header[2:end-1], ','))
    println(the)
    psi = zeros(Nx)
    # State passed to and returned from `next(Adam, ...)`, kept separately for
    # each angle: tt/pt start at 0, the m and v vectors start at zero.
    tt = 0
    tm = zeros(Nx)
    tv = zeros(Nx)
    pt = 0
    pm = zeros(Nx)
    pv = zeros(Nx)
    for _ in Base.OneTo(niter)
        tg, pg = step_theta1(L, fname, the)
        ltg, tm, tv, tt = next(Adam, tg, tm, tv, tt)
        lpg, pm, pv, pt = next(Adam, pg, pm, pv, pt)
        the += ltg
        psi += lpg
        println(the)
        println(psi)
    end
    return nothing
end


# Script entry point: run the exhaustive weight/frequency check with L = 1, Nt = 3.
chai(1, 3)
# Disabled alternative entry point: angle optimisation with L = 2.
#optimal_theta1(2)
